import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, Any, Union

import ml_collections as mlc
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
from openfold.np.residue_constants import restypes
from openfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from openfold.utils.tensor_utils import dict_multimap
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    """
    Dataset over individual chains. Each item is identified by a chain name
    taken from the keys of `alignment_index` (when given) or from the entries
    of `alignment_dir`, optionally narrowed by `filter_path` and by the
    contents of the chain data cache.
    """
    def __init__(self,
                 data_dir: str,
                 alignment_dir: str,
                 template_mmcif_dir: str,
                 max_template_date: str,
                 config: mlc.ConfigDict,
                 chain_data_cache_path: Optional[str] = None,
                 kalign_binary_path: str = '/usr/bin/kalign',
                 max_template_hits: int = 4,
                 obsolete_pdbs_file_path: Optional[str] = None,
                 template_release_dates_cache_path: Optional[str] = None,
                 shuffle_top_k_prefiltered: Optional[int] = None,
                 treat_pdb_as_distillation: bool = True,
                 filter_path: Optional[str] = None,
                 mode: str = "train",
                 alignment_index: Optional[Any] = None,
                 _output_raw: bool = False,
                 _structure_index: Optional[Any] = None,
                 ):
        """
            Args:
                data_dir:
                    A path to a directory containing mmCIF files (in train
                    mode) or FASTA files (in inference mode). In "train" and
                    "eval" mode, a structure is looked up as
                    data_dir/{file_id}{ext} with ext one of .cif, .core, .pdb.
                alignment_dir:
                    A path to a directory containing only data in the format
                    output by an AlignmentRunner
                    (defined in openfold.features.alignment_runner).
                    I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                    or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
                    files.
                template_mmcif_dir:
                    Path to a directory containing template mmCIF files.
                max_template_date:
                    Forwarded to the template featurizer's max_template_date
                    (presumably a date string cutoff for template release
                    dates; confirm the expected format).
                config:
                    A dataset config object. See openfold.config
                chain_data_cache_path:
                    Path to cache of data_dir generated by
                    scripts/generate_chain_data_cache.py. Must contain a JSON
                    dict; chains without an entry in it are dropped (with a
                    warning).
                kalign_binary_path:
                    Path to kalign binary.
                max_template_hits:
                    An upper bound on how many templates are considered. During
                    training, the templates ultimately used are subsampled
                    from this total quantity.
                obsolete_pdbs_file_path:
                    Path to the file containing replacements for obsolete PDBs.
                template_release_dates_cache_path:
                    Path to the output of scripts/generate_mmcif_cache. A
                    warning is logged when it is not provided.
                shuffle_top_k_prefiltered:
                    Whether to uniformly shuffle the top k template hits before
                    parsing max_template_hits of them. Can be used to
                    approximate DeepMind's training-time template subsampling
                    scheme much more performantly.
                treat_pdb_as_distillation:
                    Whether to assume that .pdb files in the data_dir are from
                    the self-distillation set (and should be subjected to
                    special distillation set preprocessing steps).
                filter_path:
                    Optional path to a text file with one chain name per line;
                    only the chains listed there are kept.
                mode:
                    "train", "eval", or "predict". Any other value raises
                    ValueError.
                alignment_index:
                    Optional mapping keyed by chain name. When given, its keys
                    define the chain list, and the per-chain entry is passed
                    to the data pipeline together with alignment_dir itself
                    (rather than a per-chain subdirectory).
                _output_raw:
                    If True, __getitem__ returns the raw data pipeline output
                    and the feature pipeline is never constructed.
                _structure_index:
                    Optional mapping keyed by chain name. Each entry must hold
                    a "files" list with exactly one (filename, _, _) tuple; it
                    determines the structure file's extension and is forwarded
                    to process_pdb.
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir

        self.chain_data_cache = None
        if chain_data_cache_path is not None:
            with open(chain_data_cache_path, "r") as fp:
                self.chain_data_cache = json.load(fp)
            assert isinstance(self.chain_data_cache, dict)

        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self.alignment_index = alignment_index
        self._output_raw = _output_raw
        self._structure_index = _structure_index

        # Structure file extensions probed, in this order, in __getitem__.
        self.supported_exts = [".cif", ".core", ".pdb"]

        valid_modes = ["train", "eval", "predict"]
        if mode not in valid_modes:
            raise ValueError(f'mode must be one of {valid_modes}')

        if template_release_dates_cache_path is None:
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        # Chain names come from the alignment index if present, otherwise
        # every entry of alignment_dir is treated as a chain name.
        if alignment_index is not None:
            self._chain_ids = list(alignment_index.keys())
        else:
            self._chain_ids = list(os.listdir(alignment_dir))

        if filter_path is not None:
            with open(filter_path, "r") as f:
                chains_to_include = set([l.strip() for l in f.readlines()])

            self._chain_ids = [
                c for c in self._chain_ids if c in chains_to_include
            ]

        if self.chain_data_cache is not None:
            # Filter to include only chains where we have structure data
            # (entries in chain_data_cache)
            original_chain_ids = self._chain_ids
            self._chain_ids = [
                c for c in self._chain_ids if c in self.chain_data_cache
            ]
            if len(self._chain_ids) < len(original_chain_ids):
                missing = [
                    c for c in original_chain_ids
                    if c not in self.chain_data_cache
                ]
                max_to_print = 10
                missing_examples = ", ".join(missing[:max_to_print])
                if len(missing) > max_to_print:
                    missing_examples += ", ..."
                logging.warning(
                    "Removing %d alignment entries (%s) with no corresponding "
                    "entries in chain_data_cache (%s).",
                    len(missing),
                    missing_examples,
                    chain_data_cache_path)

        # Reverse lookup for chain_id_to_idx.
        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        # If it's running template search for a monomer, then use hhsearch
        # as demonstrated in AlphaFold's run_alphafold.py code
        # https://github.com/deepmind/alphafold/blob/6c4d833fbd1c6b8e7c9a21dae5d4ada2ce777e10/run_alphafold.py#L462C1-L477
        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        # Only needed when returning processed features.
        if not self._output_raw:
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
        """Read the mmCIF file at `path`, parse it, and run the monomer data
        pipeline on it for `chain_id`. Returns the pipeline's output."""
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if mmcif_object.mmcif_object is None:
            # Re-raise the first recorded error (values are presumably
            # exception instances, since they are raised directly).
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            alignment_index=alignment_index,
            seqemb_mode=self.config.seqemb_mode.enabled
        )

        return data

    def chain_id_to_idx(self, chain_id):
        """Return the dataset index of `chain_id` (KeyError if unknown)."""
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        """Return the chain name stored at dataset index `idx`."""
        return self._chain_ids[idx]

    def __getitem__(self, idx):
        """Build the example for the chain at `idx`.

        Returns the raw data pipeline output when `_output_raw` is set,
        otherwise the processed feature dict with an added "batch_idx" entry.
        """
        name = self.idx_to_chain_id(idx)
        alignment_dir = os.path.join(self.alignment_dir, name)

        alignment_index = None
        if self.alignment_index is not None:
            # With an index, alignments are read from the top-level directory
            # using the per-chain index entry.
            alignment_dir = self.alignment_dir
            alignment_index = self.alignment_index[name]

        if self.mode == 'train' or self.mode == 'eval':
            # Names are "{file_id}_{chain_id}" (split on the LAST underscore)
            # or a bare "{file_id}", in which case chain_id is None.
            spl = name.rsplit('_', 1)
            if len(spl) == 2:
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id)
            if self._structure_index is not None:
                structure_index_entry = self._structure_index[name]
                assert (len(structure_index_entry["files"]) == 1)
                filename, _, _ = structure_index_entry["files"][0]
                ext = os.path.splitext(filename)[1]
            else:
                # Use the first supported extension that exists on disk.
                ext = None
                for e in self.supported_exts:
                    if os.path.exists(path + e):
                        ext = e
                        break

                if ext is None:
                    raise ValueError("Invalid file type")

            path += ext
            if ext == ".cif":
                data = self._parse_mmcif(
                    path, file_id, chain_id, alignment_dir, alignment_index,
                )
            elif ext == ".core":
                data = self.data_pipeline.process_core(
                    path, alignment_dir, alignment_index,
                    seqemb_mode=self.config.seqemb_mode.enabled,
                )
            elif ext == ".pdb":
                structure_index = None
                if self._structure_index is not None:
                    structure_index = self._structure_index[name]
                data = self.data_pipeline.process_pdb(
                    pdb_path=path,
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    alignment_index=alignment_index,
                    _structure_index=structure_index,
                    seqemb_mode=self.config.seqemb_mode.enabled,
                )
            else:
                raise ValueError("Extension branch missing")
        else:
            # NOTE(review): this path is NOT joined with self.data_dir, so it
            # resolves relative to the current working directory as
            # "{name}/{name}.fasta". OpenFoldSingleMultimerDataset uses
            # data_dir/{id}.fasta instead. Confirm the intended layout.
            path = os.path.join(name, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                alignment_index=alignment_index,
                seqemb_mode=self.config.seqemb_mode.enabled,
            )

        if self._output_raw:
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode
        )

        # One copy of the dataset index per entry along aatype's last
        # dimension (OpenFoldDataLoader treats that dimension as recycling).
        feats["batch_idx"] = torch.tensor(
            [idx for _ in range(feats["aatype"].shape[-1])],
            dtype=torch.int64,
            device=feats["aatype"].device)

        return feats

    def __len__(self):
        """Number of chains remaining after all filtering."""
        return len(self._chain_ids)


class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_dir: str,
                 alignment_dir: str,
                 template_mmcif_dir: str,
                 max_template_date: str,
                 config: mlc.ConfigDict,
                 mmcif_data_cache_path: Optional[str] = None,
                 kalign_binary_path: str = '/usr/bin/kalign',
                 max_template_hits: int = 4,
                 obsolete_pdbs_file_path: Optional[str] = None,
                 template_release_dates_cache_path: Optional[str] = None,
                 shuffle_top_k_prefiltered: Optional[int] = None,
                 treat_pdb_as_distillation: bool = True,
                 filter_path: Optional[str] = None,
                 mode: str = "train",
                 alignment_index: Optional[Any] = None,
                 _output_raw: bool = False,
                 _structure_index: Optional[Any] = None,
                 ):
        """
        Dataset over individual PDB entries: each item holds the features
        (and, in "train"/"eval" mode, the ground truth) of all chains of one
        mmCIF ID.
            Args:
                data_dir:
                    A path to a directory containing mmCIF files (in train
                    mode) or FASTA files (in inference mode).
                alignment_dir:
                    A path to a directory containing only data in the format
                    output by an AlignmentRunner
                    (defined in openfold.features.alignment_runner).
                    I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                    or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
                    files.
                template_mmcif_dir:
                    Path to a directory containing template mmCIF files.
                max_template_date:
                    Forwarded to the template featurizer's max_template_date.
                config:
                    A dataset config object. See openfold.config
                mmcif_data_cache_path:
                    Path to the cache of all mmCIF files generated by
                    scripts/generate_mmcif_cache.py. It should be a JSON file
                    recording which chain(s) each PDB ID contains. When given,
                    its keys define the list of entries.
                kalign_binary_path:
                    Path to kalign binary.
                max_template_hits:
                    An upper bound on how many templates are considered. During
                    training, the templates ultimately used are subsampled
                    from this total quantity.
                obsolete_pdbs_file_path:
                    Path to the file containing replacements for obsolete PDBs.
                template_release_dates_cache_path:
                    Path to the output of scripts/generate_mmcif_cache.
                shuffle_top_k_prefiltered:
                    Whether to uniformly shuffle the top k template hits before
                    parsing max_template_hits of them. Can be used to
                    approximate DeepMind's training-time template subsampling
                    scheme much more performantly.
                treat_pdb_as_distillation:
                    Whether to assume that .pdb files in the data_dir are from
                    the self-distillation set (and should be subjected to
                    special distillation set preprocessing steps).
                filter_path:
                    Optional path to a text file with one mmCIF ID per line;
                    only the IDs listed there are kept.
                mode:
                    "train", "eval", or "predict". Any other value raises
                    ValueError.
                alignment_index:
                    Optional mapping keyed by "{mmcif_id}_{chain}" names. When
                    no mmCIF cache is given, its keys define the entry list;
                    in __getitem__ the entries for the requested mmCIF ID are
                    passed to the data pipeline.
                _output_raw:
                    If True, __getitem__ returns the raw data pipeline output
                    without running the feature pipeline.
                _structure_index:
                    Stored on the instance; not read by this class.
        """
        super(OpenFoldSingleMultimerDataset, self).__init__()
        self.data_dir = data_dir
        self.mmcif_data_cache_path = mmcif_data_cache_path

        # Always defined (None when no cache is given) so consumers such as
        # OpenFoldMultimerDataset can check for it instead of hitting an
        # AttributeError.
        self.mmcif_data_cache = None
        if self.mmcif_data_cache_path is not None:
            with open(self.mmcif_data_cache_path, "r") as infile:
                self.mmcif_data_cache = json.load(infile)
            assert isinstance(self.mmcif_data_cache, dict)

        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self.alignment_index = alignment_index
        self._output_raw = _output_raw
        self._structure_index = _structure_index

        # Structure file extensions probed, in this order, in __getitem__.
        self.supported_exts = [".cif", ".core", ".pdb"]

        valid_modes = ["train", "eval", "predict"]
        if mode not in valid_modes:
            raise ValueError(f'mode must be one of {valid_modes}')

        if template_release_dates_cache_path is None:
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        if self.mmcif_data_cache_path is not None:
            mmcif_ids = list(self.mmcif_data_cache.keys())
        elif self.alignment_index is not None:
            mmcif_ids = [name.split("_")[0] for name in alignment_index.keys()]
        elif self.alignment_dir is not None:
            mmcif_ids = [name.split("_")[0] for name in os.listdir(self.alignment_dir)]
        else:
            raise ValueError("You must provide at least one of the mmcif_data_cache or alignment_dir")

        # Chain-level alignment entries ("1abc_A", "1abc_B", ...) collapse to
        # the same mmCIF ID. Keep each ID once, in first-seen order, so an
        # entry is not over-represented in the dataset and the ID -> index map
        # below stays consistent with self._mmcifs.
        self._mmcifs = list(dict.fromkeys(mmcif_ids))

        if filter_path is not None:
            with open(filter_path, "r") as f:
                mmcifs_to_include = {line.strip() for line in f}

            self._mmcifs = [
                m for m in self._mmcifs if m in mmcifs_to_include
            ]

        self._mmcif_id_to_idx_dict = {
            mmcif: i for i, mmcif in enumerate(self._mmcifs)
        }

        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        data_processor = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )
        self.data_pipeline = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor
        )
        self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
        """Read and parse the mmCIF file at `path` and run the multimer data
        pipeline on it. Returns the pipeline's output."""
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if mmcif_object.mmcif_object is None:
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            alignment_index=alignment_index
        )

        return data

    def mmcif_id_to_idx(self, mmcif_id):
        """Return the dataset index of `mmcif_id` (KeyError if unknown)."""
        return self._mmcif_id_to_idx_dict[mmcif_id]

    def idx_to_mmcif_id(self, idx):
        """Return the mmCIF ID stored at dataset index `idx`."""
        return self._mmcifs[idx]

    def __getitem__(self, idx):
        """Build the example for the mmCIF entry at `idx`.

        Returns the raw data pipeline output when `_output_raw` is set,
        otherwise the processed multimer feature dict with an added
        "batch_idx" entry.
        """
        mmcif_id = self.idx_to_mmcif_id(idx)

        alignment_index = None
        if self.alignment_index is not None:
            # Keep only this entry's chains. A prefix match is required: a
            # substring test would also pick up keys such as "x1abc_A" for
            # mmcif_id "1abc".
            prefix = f'{mmcif_id}_'
            alignment_index = {k: v for k, v in self.alignment_index.items()
                               if k.startswith(prefix)}

        if self.mode == 'train' or self.mode == 'eval':
            path = os.path.join(self.data_dir, f"{mmcif_id}")
            ext = None
            for e in self.supported_exts:
                if os.path.exists(path + e):
                    ext = e
                    break

            if ext is None:
                raise ValueError("Invalid file type")

            # TODO: Add pdb and core exts to data_pipeline for multimer
            path += ext
            if ext == ".cif":
                data = self._parse_mmcif(
                    path, mmcif_id, self.alignment_dir, alignment_index,
                )
            else:
                raise ValueError("Extension branch missing")
        else:
            path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=self.alignment_dir,
                alignment_index=alignment_index
            )

        if self._output_raw:
            return data

        # Process the features of all chains together.
        data = self.feature_pipeline.process_features(data,
                                                      mode=self.mode,
                                                      is_multimer=True)

        # One copy of the dataset index per entry along aatype's last
        # dimension (OpenFoldDataLoader treats that dimension as recycling).
        data["batch_idx"] = torch.tensor(
            [idx for _ in range(data["aatype"].shape[-1])],
            dtype=torch.int64,
            device=data["aatype"].device)

        return data

    def __len__(self):
        """Number of distinct mmCIF entries remaining after filtering."""
        return len(self._mmcifs)


def resolution_filter(resolution: Optional[float], max_resolution: float) -> bool:
    """Check that a resolution is known and does not exceed `max_resolution`.

    Args:
        resolution: The structure's resolution, or None when unknown.
            Resolutions are fractional, hence `float` rather than `int`.
        max_resolution: The largest permitted resolution (inclusive).

    Returns:
        False when `resolution` is None, otherwise
        `resolution <= max_resolution`.
    """
    if resolution is None:
        return False
    return resolution <= max_resolution


def aa_count_filter(seqs: list, max_single_aa_prop: float) -> bool:
    """Check that no single amino acid dominates the sequence(s).

    Args:
        seqs: Sequences (iterables of one-letter residue codes) whose
            residues are pooled together.
        max_single_aa_prop: Largest allowed fraction (0..1, not a percentage)
            of all residues that a single residue type may account for.

    Returns:
        False if any residue is not in `restypes`. Also False when the input
        contains no residues at all, because there is nothing to measure and
        the original code crashed on `max()` of an empty count table.
        Otherwise True iff the most frequent residue's share is
        <= max_single_aa_prop.
    """
    counts = {}
    for seq in seqs:
        for aa in seq:
            # Any non-standard residue rejects the whole entry.
            if aa not in restypes:
                return False
            counts[aa] = counts.get(aa, 0) + 1

    if not counts:
        return False

    # Every residue was counted (non-standard ones returned early), so this
    # equals the combined sequence length.
    total_len = sum(counts.values())
    return max(counts.values()) / total_len <= max_single_aa_prop


def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
    """Return True when the combined length of all sequences in `seqs` is at
    least `minimum_number_of_residues`."""
    combined_length = sum(map(len, seqs))
    return combined_length >= minimum_number_of_residues


class OpenFoldDataset(torch.utils.data.Dataset):
    """
        Implements the stochastic filters applied during AlphaFold's training.
        Because samples are drawn from the constituent datasets at random, the
        length of an OpenFoldDataset (`epoch_len`) is arbitrary. The
        (dataset_idx, datapoint_idx) pairs of an epoch are fixed by
        `reroll()`, which runs at initialization unless `_roll_at_init` is
        False and can be called again to draw a new epoch. Candidates are
        filtered lazily, as each dataset's sample stream is consumed.
    """

    def __init__(self,
                 datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
                 probabilities: Sequence[float],
                 epoch_len: int,
                 generator: Optional[torch.Generator] = None,
                 _roll_at_init: bool = True,
                 ):
        """
        Args:
            datasets: Constituent datasets. Each must expose the cache and
                index-lookup used by `looped_samples`.
            probabilities: Per-dataset weights. They are passed to
                torch.multinomial in `reroll()`, and `epoch_len * p` also sizes
                each dataset's candidate batch in `looped_samples`.
            epoch_len: Number of samples per epoch (the value of len()).
            generator: Optional torch.Generator used for every random draw.
            _roll_at_init: Whether to call `reroll()` immediately.
        """
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        # One endless, lazily evaluated stream of accepted indices per dataset.
        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
        if _roll_at_init:
            self.reroll()

    @staticmethod
    def deterministic_train_filter(
        cache_entry: Any,
        max_resolution: float = 9.,
        max_single_aa_prop: float = 0.8,
        *args, **kwargs
    ) -> bool:
        """Hard filters: reject chains with unknown or too-coarse resolution,
        non-standard residues, or a dominant single residue type. Reads the
        "resolution" (optional) and "seq" keys of `cache_entry`."""
        # Hard filters
        resolution = cache_entry.get("resolution", None)
        seqs = [cache_entry["seq"]]

        return all([resolution_filter(resolution=resolution,
                                      max_resolution=max_resolution),
                    aa_count_filter(seqs=seqs,
                                    max_single_aa_prop=max_single_aa_prop)])

    @staticmethod
    def get_stochastic_train_filter_prob(
        cache_entry: Any,
        *args, **kwargs
    ) -> float:
        """Acceptance probability for a chain.

        The probability is the product of 1 / cluster_size (only when
        "cluster_size" is present and positive) and
        max(min(len(seq), 512), 256) / 512, which lies in [0.5, 1].
        """
        # Stochastic filters
        probabilities = []

        cluster_size = cache_entry.get("cluster_size", None)
        if cluster_size is not None and cluster_size > 0:
            probabilities.append(1 / cluster_size)

        chain_length = len(cache_entry["seq"])
        probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

        # Risk of underflow here?
        out = 1
        for p in probabilities:
            out *= p

        return out

    def looped_shuffled_dataset_idx(self, dataset_len):
        """Yield dataset indices forever, as a fresh uniform random
        permutation of range(dataset_len) on every pass."""
        while True:
            # Uniformly shuffle each dataset's indices
            weights = [1. for _ in range(dataset_len)]
            shuf = torch.multinomial(
                torch.tensor(weights),
                num_samples=dataset_len,
                replacement=False,
                generator=self.generator,
            )
            for idx in shuf:
                yield idx

    def looped_samples(self, dataset_idx):
        """Yield accepted datapoint indices of one dataset forever.

        Each round draws up to `max_cache_len` candidates from the shuffled
        index stream. It drops those failing the deterministic filter, keeps
        each survivor with its stochastic acceptance probability, and yields
        the kept indices.
        """
        # At least one candidate per round. Because of int() truncation, a
        # small weight (epoch_len * p < 1) would otherwise give empty rounds
        # forever, and next() on this stream would never return.
        max_cache_len = max(1, int(self.epoch_len * self.probabilities[dataset_idx]))
        dataset = self.datasets[dataset_idx]
        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
        chain_data_cache = dataset.chain_data_cache
        while True:
            weights = []
            idx = []
            for _ in range(max_cache_len):
                candidate_idx = next(idx_iter)
                chain_id = dataset.idx_to_chain_id(candidate_idx)
                chain_data_cache_entry = chain_data_cache[chain_id]
                if not self.deterministic_train_filter(chain_data_cache_entry):
                    continue

                p = self.get_stochastic_train_filter_prob(
                    chain_data_cache_entry,
                )
                # Row = [reject weight, accept weight] for one Bernoulli draw.
                weights.append([1. - p, p])
                idx.append(candidate_idx)

            # Everything in this round was filtered out; try another round.
            if len(weights) == 0:
                continue

            samples = torch.multinomial(
                torch.tensor(weights),
                num_samples=1,
                generator=self.generator,
            )
            # Keep the (1, 1) shape for a single row so the zip below still
            # iterates over rows.
            samples = samples.squeeze() if samples.numel() > 1 else samples

            cache = [i for i, s in zip(idx, samples) if s]

            for datapoint_idx in cache:
                yield datapoint_idx

    def __getitem__(self, idx):
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

    def reroll(self):
        """Draw a new epoch: choose a dataset for each of the `epoch_len`
        slots according to `probabilities`, then take the next accepted index
        from that dataset's stream."""
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )
        self.datapoints = []
        for dataset_idx in dataset_choices:
            samples = self._samples[dataset_idx]
            datapoint_idx = next(samples)
            self.datapoints.append((dataset_idx, datapoint_idx))


class OpenFoldMultimerDataset(OpenFoldDataset):
    """
    Create a torch Dataset object for multimer training and
    add filtering steps described in AlphaFold Multimer's paper:
    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
    """

    def __init__(self,
                 datasets: Sequence[OpenFoldSingleMultimerDataset],
                 probabilities: Sequence[float],
                 epoch_len: int,
                 generator: Optional[torch.Generator] = None,
                 _roll_at_init: bool = True
                 ):
        super(OpenFoldMultimerDataset, self).__init__(datasets=datasets,
                                                      probabilities=probabilities,
                                                      epoch_len=epoch_len,
                                                      generator=generator,
                                                      _roll_at_init=_roll_at_init)

    @staticmethod
    def deterministic_train_filter(
        cache_entry: Any,
        is_distillation: bool,
        max_resolution: float = 9.,
        max_single_aa_prop: float = 0.8,
        minimum_number_of_residues: int = 200,
        *args, **kwargs
    ) -> bool:
        """
        Implement multimer training filtering criteria described in
        https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1

        The resolution and residue-composition filters always apply. The
        minimum total-residue filter applies only when `is_distillation` is
        True. Reads the "resolution" (optional) and "seqs" keys of
        `cache_entry`.
        """
        resolution = cache_entry.get("resolution", None)
        seqs = cache_entry["seqs"]

        return all([resolution_filter(resolution=resolution,
                                      max_resolution=max_resolution),
                    aa_count_filter(seqs=seqs,
                                    max_single_aa_prop=max_single_aa_prop),
                    (not is_distillation or all_seq_len_filter(seqs=seqs,
                                                               minimum_number_of_residues=minimum_number_of_residues))])

    @staticmethod
    def get_stochastic_train_filter_prob(
        cache_entry: Any,
        *args, **kwargs
    ) -> list:
        """Per-chain acceptance probabilities.

        With "cluster_sizes" present, each probability is 1 / size, or 1 for
        a non-positive size. Otherwise every chain listed in "chain_ids" gets
        probability 1.
        """
        # Stochastic filters
        cluster_sizes = cache_entry.get("cluster_sizes")
        if cluster_sizes is not None:
            return [1 / c if c > 0 else 1 for c in cluster_sizes]

        num_chains = len(cache_entry["chain_ids"])
        return [1.] * num_chains

    def looped_samples(self, dataset_idx):
        """Yield accepted mmCIF-entry indices of one dataset forever.

        Each round draws up to `max_cache_len` candidate entries and drops
        those failing the deterministic filter. It then runs one Bernoulli
        trial per chain of each survivor and yields the entry index once per
        accepted chain.

        Raises:
            ValueError: (on first use) if the dataset has no mmCIF data cache.
        """
        # At least one candidate per round. Because of int() truncation, a
        # small weight (epoch_len * p < 1) would otherwise give empty rounds
        # forever.
        max_cache_len = max(1, int(self.epoch_len * self.probabilities[dataset_idx]))
        dataset = self.datasets[dataset_idx]
        is_distillation = dataset.treat_pdb_as_distillation
        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
        mmcif_data_cache = getattr(dataset, "mmcif_data_cache", None)
        if mmcif_data_cache is None:
            raise ValueError(
                "OpenFoldMultimerDataset requires each dataset to be built "
                "with an mmcif_data_cache_path"
            )
        while True:
            weights = []
            idx = []
            for _ in range(max_cache_len):
                candidate_idx = next(idx_iter)
                mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
                if not self.deterministic_train_filter(cache_entry=mmcif_data_cache_entry,
                                                       is_distillation=is_distillation):
                    continue

                chain_probs = self.get_stochastic_train_filter_prob(
                    mmcif_data_cache_entry,
                )
                # One [reject, accept] row per chain, all pointing at the
                # same entry index.
                weights.extend([[1. - p, p] for p in chain_probs])
                idx.extend([candidate_idx] * len(chain_probs))

            # Nothing survived this round; torch.multinomial cannot sample
            # from an empty tensor, so try another round.
            if len(weights) == 0:
                continue

            samples = torch.multinomial(
                torch.tensor(weights),
                num_samples=1,
                generator=self.generator,
            )
            # Keep the (1, 1) shape for a single row. Squeezing it to 0-d
            # would make the zip below fail.
            samples = samples.squeeze() if samples.numel() > 1 else samples

            cache = [i for i, s in zip(idx, samples) if s]

            for datapoint_idx in cache:
                yield datapoint_idx


class OpenFoldBatchCollator:
    """Collate a list of per-example feature dicts into one batched dict."""

    def __call__(self, prots):
        # dict_multimap applies torch.stack(..., dim=0) across the
        # per-example values (presumably matched by key). Each batched value
        # therefore gains a new leading batch dimension.
        return dict_multimap(partial(torch.stack, dim=0), prots)


class OpenFoldDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader that samples batch-level properties for every batch
    (currently only "no_recycling_iters"). It crops the last dimension of
    every batch tensor, which it treats as the recycling dimension, to
    match the sampled value.
    """

    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        # `generator` drives property sampling only; it is deliberately not
        # forwarded to the base DataLoader.
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        """Build the sampling table used by _add_batch_properties.

        Sets `self.prop_keys` (a tuple of property names) and
        `self.prop_probs_tensor` (float32, one probability row per property,
        zero-padded to a common width).
        """
        stage_cfg = self.config[self.stage]
        max_iters = self.config.common.max_recycling_iters
        n_choices = max_iters + 1

        if stage_cfg.uniform_recycling:
            # Each recycling count 0..max_iters is equally likely.
            recycling_probs = [1. / n_choices for _ in range(n_choices)]
        else:
            # Always run the maximum number of recycling iterations.
            recycling_probs = [0.] * n_choices
            recycling_probs[-1] = 1.

        property_table = (("no_recycling_iters", recycling_probs),)
        prop_names = tuple(name for name, _ in property_table)
        prob_rows = [row for _, row in property_table]

        width = max(len(row) for row in prob_rows)
        self.prop_keys = prop_names
        self.prop_probs_tensor = torch.tensor(
            [row + [0.] * (width - len(row)) for row in prob_rows],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        """Sample one value per property, store it in `batch` broadcast over
        the batch and recycling dimensions, and crop every tensor's last
        dimension to `no_recycling_iters + 1`. "gt_features" is set aside
        during the crop and re-attached afterwards (None if it was absent)."""
        gt_features = batch.pop('gt_features', None)
        drawn = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,  # one draw per property row
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        lead_dims = aatype.shape[:-2]
        n_recycle = aatype.shape[-1]
        keep = n_recycle
        for row, prop_key in enumerate(self.prop_keys):
            value = int(drawn[row][0])
            scalar = torch.tensor(
                value,
                device=aatype.device,
                requires_grad=False
            )
            scalar_shape = scalar.shape
            broadcastable = scalar.view(
                (1,) * len(lead_dims) + scalar_shape + (1,)
            )
            batch[prop_key] = broadcastable.expand(
                lead_dims + scalar_shape + (n_recycle,)
            )

            if prop_key == "no_recycling_iters":
                keep = value

        def crop_recycling(t):
            return t[..., :keep + 1]

        batch = tensor_tree_map(crop_recycling, batch)
        batch['gt_features'] = gt_features

        return batch

    def __iter__(self):
        base_iter = super().__iter__()
        return (self._add_batch_properties(b) for b in base_iter)


class OpenFoldDataModule(pl.LightningDataModule):
    """
    LightningDataModule that builds OpenFold datasets and dataloaders.

    The module runs in one of two modes, fixed at construction time:

    * training mode (``train_data_dir`` given): ``setup`` builds an
      ``OpenFoldDataset`` over the training set and, optionally, a
      distillation set, plus an optional validation dataset;
    * inference mode (only ``predict_data_dir`` given): ``setup`` builds a
      single prediction dataset.

    The per-split datasets are ``OpenFoldSingleDataset`` instances sharing
    the template and kalign settings passed to the constructor.
    """
    def __init__(self,
                 config: mlc.ConfigDict,
                 template_mmcif_dir: str,
                 max_template_date: str,
                 train_data_dir: Optional[str] = None,
                 train_alignment_dir: Optional[str] = None,
                 train_chain_data_cache_path: Optional[str] = None,
                 distillation_data_dir: Optional[str] = None,
                 distillation_alignment_dir: Optional[str] = None,
                 distillation_chain_data_cache_path: Optional[str] = None,
                 val_data_dir: Optional[str] = None,
                 val_alignment_dir: Optional[str] = None,
                 predict_data_dir: Optional[str] = None,
                 predict_alignment_dir: Optional[str] = None,
                 kalign_binary_path: str = '/usr/bin/kalign',
                 train_filter_path: Optional[str] = None,
                 distillation_filter_path: Optional[str] = None,
                 obsolete_pdbs_file_path: Optional[str] = None,
                 template_release_dates_cache_path: Optional[str] = None,
                 batch_seed: Optional[int] = None,
                 train_epoch_len: int = 50000,
                 _distillation_structure_index_path: Optional[str] = None,
                 alignment_index_path: Optional[str] = None,
                 distillation_alignment_index_path: Optional[str] = None,
                 **kwargs
                 ):
        """
            Args:
                config:
                    Full OpenFold config. ``setup`` reads ``config.train``,
                    ``config.eval`` and ``config.predict``;
                    ``_gen_dataloader`` reads
                    ``config.data_module.data_loaders``.
                template_mmcif_dir, max_template_date, kalign_binary_path,
                obsolete_pdbs_file_path, template_release_dates_cache_path:
                    Forwarded unchanged to every dataset.
                train_* / distillation_* / val_* / predict_* paths:
                    Data, alignment, chain-cache and filter locations for
                    each split. Either ``train_data_dir`` (training mode)
                    or ``predict_data_dir`` (inference mode) is required.
                batch_seed:
                    If set, the dataloader generator is seeded with
                    ``batch_seed`` and the training ``OpenFoldDataset``
                    generator with ``batch_seed + 1``.
                train_epoch_len:
                    ``epoch_len`` passed to ``OpenFoldDataset``.
                _distillation_structure_index_path, alignment_index_path,
                distillation_alignment_index_path:
                    Optional JSON files, parsed here and handed to the
                    corresponding datasets.
                **kwargs:
                    Accepted and ignored.
            Raises:
                ValueError: if neither train_data_dir nor predict_data_dir
                    is given, or a required alignment directory is missing.
        """
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_filter_path = train_filter_path
        self.distillation_filter_path = distillation_filter_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        # Filled in by setup(). Defined up front so the dataloader hooks see
        # None (instead of raising AttributeError) for datasets the current
        # mode never builds, e.g. val_dataloader() in inference mode.
        self.train_dataset = None
        self.eval_dataset = None
        self.predict_dataset = None

        if self.train_data_dir is None and self.predict_data_dir is None:
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        if self.training_mode and train_alignment_dir is None:
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif not self.training_mode and predict_alignment_dir is None:
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )
        elif val_data_dir is not None and val_alignment_dir is None:
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )

        # An ad-hoc measure for our particular filesystem restrictions
        self._distillation_structure_index = self._load_optional_json(
            _distillation_structure_index_path
        )
        self.alignment_index = self._load_optional_json(alignment_index_path)
        self.distillation_alignment_index = self._load_optional_json(
            distillation_alignment_index_path
        )

    @staticmethod
    def _load_optional_json(path):
        """Return the parsed JSON file at ``path``, or None if path is None."""
        if path is None:
            return None
        with open(path, "r") as fp:
            return json.load(fp)

    def setup(self, stage=None):
        """Instantiate the datasets for the mode chosen at construction.

        Args:
            stage: Lightning stage name; accepted for the hook signature and
                otherwise unused.
        """
        # Most of the arguments are the same for the three datasets
        dataset_gen = partial(OpenFoldSingleDataset,
                              template_mmcif_dir=self.template_mmcif_dir,
                              max_template_date=self.max_template_date,
                              config=self.config,
                              kalign_binary_path=self.kalign_binary_path,
                              template_release_dates_cache_path=self.template_release_dates_cache_path,
                              obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)

        if self.training_mode:
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                chain_data_cache_path=self.train_chain_data_cache_path,
                alignment_dir=self.train_alignment_dir,
                filter_path=self.train_filter_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                alignment_index=self.alignment_index,
            )

            datasets = [train_dataset]
            probabilities = [1.]
            if self.distillation_data_dir is not None:
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    chain_data_cache_path=self.distillation_chain_data_cache_path,
                    alignment_dir=self.distillation_alignment_dir,
                    filter_path=self.distillation_filter_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )
                d_prob = self.config.train.distillation_prob
                # Weights passed to OpenFoldDataset, in the same order as
                # `datasets`.
                datasets.append(distillation_dataset)
                probabilities = [1. - d_prob, d_prob]

            generator = None
            if self.batch_seed is not None:
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)

            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                generator=generator,
                _roll_at_init=False,
            )

            if self.val_data_dir is not None:
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    filter_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                filter_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage=None):
        """Build an ``OpenFoldDataLoader`` for ``stage``.

        Args:
            stage: one of ``"train"``, ``"eval"`` or ``"predict"``.
        Returns:
            An ``OpenFoldDataLoader`` with ``OpenFoldBatchCollator`` and the
            batch size / worker count from ``config.data_module.data_loaders``.
        Raises:
            ValueError: if ``stage`` is unknown, or if the dataset for that
                stage was not built by ``setup``.
        """
        generator = None
        if self.batch_seed is not None:
            generator = torch.Generator()
            generator = generator.manual_seed(self.batch_seed)

        if stage == "train":
            dataset = self.train_dataset
        elif stage == "eval":
            dataset = self.eval_dataset
        elif stage == "predict":
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        if dataset is None:
            raise ValueError(
                f"No dataset available for stage '{stage}'; make sure "
                f"setup() has run and the matching data directory was given"
            )

        if stage == "train":
            # Filter the dataset, if necessary
            dataset.reroll()

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        """Return the training dataloader."""
        return self._gen_dataloader("train")

    def val_dataloader(self):
        """Return the validation dataloader, or [] if there is no eval set."""
        if self.eval_dataset is not None:
            return self._gen_dataloader("eval")
        return []

    def predict_dataloader(self):
        """Return the prediction dataloader."""
        return self._gen_dataloader("predict")


class OpenFoldMultimerDataModule(OpenFoldDataModule):
    """
    Create a datamodule specifically for multimer training.

    Compared to OpenFoldDataModule, OpenFoldMultimerDataModule builds
    OpenFoldSingleMultimerDataset / OpenFoldMultimerDataset instances and
    additionally takes mmCIF data cache paths (train_mmcif_data_cache_path,
    val_mmcif_data_cache_path). These caches are produced by
    scripts/generate_mmcif_cache.py and record which chain(s) each mmCIF
    file contains.
    """

    def __init__(self, config: mlc.ConfigDict,
                 template_mmcif_dir: str, max_template_date: str,
                 train_data_dir: Optional[str] = None,
                 train_mmcif_data_cache_path: Optional[str] = None,
                 val_mmcif_data_cache_path: Optional[str] = None,
                 **kwargs):
        """
            Args:
                config, template_mmcif_dir, max_template_date,
                train_data_dir, **kwargs:
                    Forwarded to OpenFoldDataModule.__init__ (which also
                    sets ``training_mode`` and validates the directories).
                train_mmcif_data_cache_path:
                    mmCIF chain cache for the training set.
                val_mmcif_data_cache_path:
                    mmCIF chain cache for the validation set.
        """
        super(OpenFoldMultimerDataModule, self).__init__(
            config,
            template_mmcif_dir,
            max_template_date,
            train_data_dir=train_data_dir,
            **kwargs
        )

        self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
        self.val_mmcif_data_cache_path = val_mmcif_data_cache_path

    def setup(self, stage=None, setup=None):
        """Instantiate the multimer datasets for the configured mode.

        Args:
            stage: Lightning stage name. Lightning calls the hook as
                ``setup(stage=...)``, so the parameter must carry this name.
                Unused: the mode is fixed at construction time.
            setup: legacy name of ``stage`` from an earlier signature, kept
                so existing keyword callers keep working. Unused.
        """
        # Most of the arguments are the same for the three datasets
        dataset_gen = partial(OpenFoldSingleMultimerDataset,
                              template_mmcif_dir=self.template_mmcif_dir,
                              max_template_date=self.max_template_date,
                              config=self.config,
                              kalign_binary_path=self.kalign_binary_path,
                              template_release_dates_cache_path=self.template_release_dates_cache_path,
                              obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)

        if self.training_mode:
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                mmcif_data_cache_path=self.train_mmcif_data_cache_path,
                alignment_dir=self.train_alignment_dir,
                filter_path=self.train_filter_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                alignment_index=self.alignment_index,
            )

            datasets = [train_dataset]
            probabilities = [1.]
            if self.distillation_data_dir is not None:
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    filter_path=self.distillation_filter_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )
                d_prob = self.config.train.distillation_prob
                # Weights passed to OpenFoldMultimerDataset, in the same
                # order as `datasets`.
                datasets.append(distillation_dataset)
                probabilities = [1. - d_prob, d_prob]

            generator = None
            if self.batch_seed is not None:
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)

            self.train_dataset = OpenFoldMultimerDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                generator=generator,
                _roll_at_init=True,
            )

            if self.val_data_dir is not None:
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mmcif_data_cache_path=self.val_mmcif_data_cache_path,
                    filter_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                filter_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )


class DummyDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves deep copies of a single pickled batch.

    Every valid index returns a fresh ``copy.deepcopy`` of the batch loaded
    from ``batch_path``, so callers may mutate what they receive.
    """

    def __init__(self, batch_path, length=1000):
        """
            Args:
                batch_path:
                    Path to a pickle file containing the batch.
                length:
                    Value reported by ``__len__``; defaults to 1000, the
                    previously hard-coded size.
            Raises:
                ValueError: if ``length`` is negative.
        """
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")
        with open(batch_path, "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code embedded in
            # the file; only point batch_path at trusted files.
            self.batch = pickle.load(f)
        self._length = length

    def __getitem__(self, idx):
        # Bounds check so sequence-protocol iteration (iter/list/for) ends
        # instead of looping forever; supports negative indices like a list.
        if not -self._length <= idx < self._length:
            raise IndexError(
                f"index {idx} out of range for DummyDataset of length "
                f"{self._length}"
            )
        return copy.deepcopy(self.batch)

    def __len__(self):
        return self._length


class DummyDataLoader(pl.LightningDataModule):
    """LightningDataModule whose training data comes from a ``DummyDataset``.

    Despite its name this is a data module, not a DataLoader:
    ``train_dataloader`` wraps the dataset in a ``torch.utils.data.DataLoader``
    constructed with default arguments.
    """

    def __init__(self, batch_path):
        super().__init__()
        dummy = DummyDataset(batch_path)
        self.dataset = dummy

    def train_dataloader(self):
        loader_cls = torch.utils.data.DataLoader
        return loader_cls(self.dataset)
